#!/usr/bin/env python3

# Test initial conditions

from boututils.run_wrapper import shell, shell_safe, launch_safe
from boutdata.collect import collect

import configparser
import itertools
from scipy.special import erf
import numpy as np
import os

#requires not make

########################################
# Implementations of BOUT++ functions


def bout_round(x):
    """
    Half-unit shift used by BOUT++ rounding

    Strictly positive values are moved up by 0.5; zero and negative
    values are moved down by 0.5. The result is NOT truncated to an
    integer here - genRand applies int() to the final expression.
    """
    if x > 0.0:
        return x + 0.5
    return x - 0.5


def genRand(seed):
    """
    BOUT++ pseudo-random number generator

    The generator keeps no state between calls: the value returned
    depends only on ``seed``, so a new seed must be passed in for each
    new number. Negative seeds are treated as their magnitude.
    """
    # Only the magnitude of the seed is used
    seed = -seed if seed < 0.0 else seed

    # The rounded seed selects how many times the map is applied
    # (between 11 and 89 iterations)
    niter = int(11 + (23 + bout_round(seed)) % 79)

    # Starting value, which lies strictly between 0 and 1
    offset = 0.01
    period = 1.23456789
    x = (offset + np.mod(seed, period)) / (period + 2.*offset)

    # Apply the logistic map niter times
    for _ in range(niter):
        x = 3.99 * x * (1. - x)

    return x


def ballooning(x, ball_n=3):
    """
    Placeholder for the BOUT++ ballooning function

    Not implemented: every call raises NotImplementedError("ballooning")
    and the arguments are ignored.
    """
    unsupported = NotImplementedError("ballooning")
    raise unsupported


def gauss(x, width=1.0):
    """
    Normalised gaussian

    Evaluates exp(-x^2 / (2 width^2)) / (width sqrt(2 pi)), which
    integrates to one over the real line for any positive ``width``.

    x:     point(s) at which to evaluate (scalar or numpy array)
    width: standard deviation of the gaussian (default 1.0)
    """
    # The 1/width factor is what keeps the integral equal to one when
    # width != 1; for the default width it leaves the value unchanged
    return np.exp(-x**2 / (2 * width**2)) / (width * np.sqrt(2 * np.pi))


def mixmode(x, seed=0.5):
    """
    Sum of 14 cosine modes with pseudo-random phases

    Mode ``m`` (m = 0..13) contributes cos(m*x + phase) scaled by
    1 / (1 + |m - 4|)^2, so mode 4 has the largest amplitude. The phase
    of each mode comes from genRand(seed + m).
    """
    total = 0.0
    for mode in range(14):
        # Map the generator output onto a phase: pi * (2*r - 1)
        phase = np.pi * (2.*genRand(seed + mode) - 1.)
        # Amplitude falls off with distance from mode 4
        amplitude = 1./(1. + np.abs(mode-4.))**2
        total += amplitude * np.cos(mode * x + phase)
    return total


def heaviside(x):
    """
    Unit step: 1 where x is strictly positive, 0 elsewhere (x == 0 gives 0)
    """
    is_positive = x > 0
    # Multiplying by 1 turns the boolean result into integers
    return 1 * is_positive


def tanhhat(x, width, centre, steepness):
    """
    BOUT++ TanhHat function

    A smooth top-hat built from two tanh steps: one rising at
    centre - width/2 and one falling at centre + width/2, each with
    the given steepness.
    """
    lower_edge = centre - width/2.
    upper_edge = centre + width/2.
    rising = np.tanh( steepness * (x - lower_edge))
    falling = np.tanh(-steepness * (x - upper_edge))
    return 0.5*(rising + falling)


def atan(x, y=None):
    """
    Arctangent with one or two arguments

    atan(x) is np.arctan(x); atan(x, y) is np.arctan2(x, y). The
    two-argument form is chosen whenever y is not None.
    """
    if y is None:
        return np.arctan(x)
    return np.arctan2(x, y)


def max(*args):
    """
    Pointwise maximum over all arguments

    Deliberately shadows the builtin so that input-file expressions can
    call ``max`` on arrays. Raises IndexError when called with no
    arguments.
    """
    largest = args[0]
    # args[0] is compared with itself on the first pass, so even a single
    # argument goes through np.maximum
    for candidate in args:
        largest = np.maximum(candidate, largest)
    return largest


def min(*args):
    """
    Pointwise minimum over all arguments

    Deliberately shadows the builtin so that input-file expressions can
    call ``min`` on arrays. Raises IndexError when called with no
    arguments.
    """
    smallest = args[0]
    # args[0] is compared with itself on the first pass, so even a single
    # argument goes through np.minimum
    for candidate in args:
        smallest = np.minimum(candidate, smallest)
    return smallest


def fmod(x, denominator=1.0):
    """
    Remainder of x / denominator, fmod convention (``rem`` in Matlab)

    Thin wrapper over np.fmod; with no second argument the remainder is
    taken with respect to 1.0.
    """
    remainder = np.fmod(x, denominator)
    return remainder


# Names as they appear in BOUT++ expressions, bound to their Python
# equivalents (most are plain numpy functions). Several of these
# intentionally shadow Python builtins such as abs.

# Trigonometric and inverse trigonometric
sin = np.sin
cos = np.cos
tan = np.tan
asin = np.arcsin
acos = np.arccos

# Hyperbolic
sinh = np.sinh
cosh = np.cosh
tanh = np.tanh

# Exponential, logarithm, powers and magnitude
exp = np.exp
log = np.log
power = np.power
sqrt = np.sqrt
abs = np.abs

# Functions implemented in this file
ballooning = ballooning  # already has the BOUT++ name; listed for completeness
H = heaviside
TanhHat = tanhhat

# Constants
pi = np.pi

########################################
# Running the test

# Settings for the test
success = True          # cleared when any variable fails its check
tolerance = 1e-13       # upper bound on the RMS error for a pass
cmd = "./test_initial"
datadir = "data"
inputfile = os.path.join(datadir, "BOUT.inp")

# Parse the input file. A "[global]" header is fed to the parser ahead of
# the file's own lines, so options that come before the first section
# header still end up inside a section.
config = configparser.ConfigParser()
with open(inputfile, "r") as f:
    lines = itertools.chain(['[global]'], f)
    config.read_file(lines, source=inputfile)

# Each section that has a "function" option is a variable to check
varlist = []
for key, values in config.items():
    if 'function' in values:
        varlist.append(key)

# The coordinate arrays are collected separately rather than checked;
# list.remove raises ValueError if one of them is absent from the input
for coord in ("var_x", "var_y", "var_z"):
    varlist.remove(coord)

# Rebuild the executable from scratch, keeping the build output in make.log
print("Making initial conditions test")
for make_command in ("make clean", "make > make.log"):
    shell_safe(make_command)

# Run the executable on each processor count and compare every variable it
# writes against the Python evaluation of the same input-file expression
nprocs = [1, 2, 3, 4]
for nproc in nprocs:
    status, out = launch_safe(cmd, nproc=nproc, pipe=True, verbose=True)
    # Keep the output of each run in its own log file
    with open("run.log.{}".format(nproc), "w") as f:
        f.write(out)

    if status != 0:
        print("=> Could not run test")
        print(status)
        # Raise SystemExit directly: the exit() helper is installed by the
        # site module for interactive sessions and is not always available
        raise SystemExit(status)

    # Collect the coordinate arrays separately; the expressions below
    # refer to them by the names x, y and z
    x = collect("var_x", xguards=True, yguards=True, path=datadir, info=False)
    y = collect("var_y", xguards=True, yguards=True, path=datadir, info=False)
    z = collect("var_z", xguards=True, yguards=True, path=datadir, info=False)

    # Evaluate the functions
    for var in varlist:
        function = config[var]['function']
        # The input file writes powers with "^"; Python uses "**"
        function = function.replace("^", "**")
        if ":" in function:
            print("{} contains reference to variable - not possible to resolve at this time".format(var))
            continue
        try:
            # NOTE: eval() executes text taken from the input file using the
            # names defined in this module. That is acceptable for the
            # test's own data/BOUT.inp but must never be pointed at
            # untrusted input.
            analytic = eval(function)
        except NotImplementedError as err:
            print("{} not implemented, skipping".format(err.args[0]))
            continue

        data = collect(var, xguards=True, yguards=True, path=datadir, info=False)
        # Root-mean-square difference between expected and computed values
        E2 = np.sqrt(np.mean((analytic - data)**2))
        if E2 < tolerance:
            success_string = "PASS"
        elif var in ("mixmode", "mixmode_seed") and E2 < 1e-3:
            # A looser bound applies to the mixmode variables, but only on
            # i686 is exceeding the strict tolerance tolerated
            import platform
            arch = platform.machine()
            if arch == 'i686':
                # This can happen due to excess precision e.g. on X87 architecture
                success_string = "WARNING"
            else:
                print("This should only happen in i686 with an x87 math architecture.")
                print("We detected however an %s architecture."%arch)
                success_string = "FAIL"
                success = False
        else:
            success_string = "FAIL"
            success = False
        print("\tChecking {var:<12}: l-2: {err:.4e} ... {success}".format(var=var, err=E2,
                                                                  success=success_string))

# Report the overall outcome through the process exit status (0 = pass).
# SystemExit is raised directly because the exit() helper is installed by
# the site module for interactive sessions and is not always available.
if success:
    print(" => All tests passed")
    raise SystemExit(0)

print(" => Some failed tests")
raise SystemExit(1)
